import os
import sys


# The target name is taken from the first command-line argument.
target = sys.argv[1]

# Each entry pairs a score threshold with the targets that use it.  The
# threshold is multiplied by the query size further down the script to obtain
# the minimum score an alignment needs in order to be kept.
_threshold_groups = (
    (0.8, ('chrM', 'rRNA', 'tRNA', 'snRNA', 'scRNA', 'histone', 'RPPH',
           'snoRNA', 'scaRNA', 'RMRP', 'yRNA', 'snar', 'vRNA', 'TERC',
           'MALAT1', 'snhg')),
    (0.9, ('mRNA', 'lncRNA', 'gencode', 'fantomcat', 'novel', 'genome')),
)
for score_threshold, _group in _threshold_groups:
    if target in _group:
        break
else:
    # No group lists this target.
    raise Exception("Unknown target %s" % target)


# Gather the files in the current directory whose names have the form
# "<target>.<index>.psl".
filenames = []
for entry in os.listdir("."):
    parts = entry.split(".")
    if len(parts) == 3 and parts[0] == target and parts[2] == 'psl':
        # The middle part must be an integer; int() raises ValueError if not.
        int(parts[1])
        filenames.append(entry)

def keyfunction(filename):
    """Return the integer index embedded in a "<target>.<index>.psl" name.

    Used as the sort key so that files are ordered numerically rather than
    alphabetically.  Raises AssertionError if the name does not consist of
    three dot-separated parts starting with the module-level ``target`` and
    ending in "psl", and ValueError if the middle part is not an integer.
    """
    pieces = filename.split(".")
    assert len(pieces) == 3
    prefix, number, suffix = pieces
    assert prefix == target
    assert suffix == 'psl'
    return int(number)

# Order the files by their numeric index and remember how many there are.
filenames = sorted(filenames, key=keyfunction)
n = len(filenames)


def _last_line(filename):
    """Return the final line of the file *filename*, or None if it is empty."""
    last = None
    with open(filename) as handle:
        for last in handle:
            pass
    return last


# Verify the log files of each of the n jobs before merging their results.
# Every check raises an exception on failure so that the script stops before
# any output is written.
if target in ("mRNA", "lncRNA", "gencode", "fantomcat", "genome"):
    # (letters, sequences) per target.  These values are not read anywhere
    # else in this script; presumably they record the size of each reference
    # (total letters and number of sequences) -- TODO confirm.
    reference_sizes = {
        "mRNA": (471870682, 114043),
        "lncRNA": (38900666, 15572),
        "gencode": (360641787, 228048),
        "fantomcat": (2151827432, 709176),
        "genome": (3209286105, 455),
    }
    letters, sequences = reference_sizes[target]
    # Every line in a stderr file must start with one of these prefixes.
    expected_prefixes = (
        "[M::bwa_idx_load_from_disk]",
        "[M::process]",
        "[M::mem_process_seqs]",
        "[main]",
    )
    checked = 0
    failed = 0
    errors = 0
    for i in range(n):
        filename = "script_%s_%d.stderr" % (target, i)
        print("Checking error file %s" % filename)
        with open(filename) as handle:
            for line in handle:
                if not line.startswith(expected_prefixes):
                    raise Exception(line)
        filename = "script_%s_%d.stdout" % (target, i)
        print("Checking output file %s" % filename)
        with open(filename) as handle:
            if handle.read():
                raise Exception("Unexpected output in %s" % filename)
        filename = "%s.%d.out" % (target, i)
        print("Checking pslCheck output file %s" % filename)
        # Check the last line only; previous lines contain detailed
        # information on alignments going beyond tSize.
        summary = _last_line(filename)
        if summary is None:
            # An empty file must not silently reuse a line read from a
            # previously checked file.
            raise Exception("pslCheck output file %s is empty" % filename)
        if summary != "stdin is empty\n":
            words = summary.split()
            if (len(words) != 6
                    or words[0] != "checked:"
                    or words[2] != "failed:"
                    or words[4] != "errors:"):
                raise Exception("Unexpected pslCheck summary in %s: %s"
                                % (filename, summary))
            checked += int(words[1])
            failed += int(words[3])
            errors += int(words[5])
    if failed != 0 or errors != 0:
        raise Exception("pslCheck reported failed: %d errors: %d"
                        % (failed, errors))
    print("Total checked: %d failed: %d errors: %d" % (checked, failed, errors))
else:
    for i in range(n):
        filename = "script_%s_%d.stdout" % (target, i)
        print("Checking output file %s" % filename)
        # The last line of the output must be 'Done'; an empty file fails.
        final = _last_line(filename)
        if final is None or final.strip() != 'Done':
            raise Exception("Output file %s does not end with 'Done'" % filename)
        filename = "script_%s_%d.stderr" % (target, i)
        print("Checking error file %s" % filename)
        with open(filename) as handle:
            if handle.read():
                raise Exception("Unexpected error output in %s" % filename)


# Header block written verbatim at the top of the merged output file.
# The column titles are separated by tab characters, so the whitespace inside
# this literal must be preserved exactly.
header = """\
psLayout version 3

match	mis- 	rep. 	N's	Q gap	Q gap	T gap	T gap	strand	Q        	Q   	Q    	Q  	T        	T   	T    	T  	block	blockSizes 	qStarts	 tStarts
     	match	match	   	count	bases	count	bases	      	name     	size	start	end	name     	size	start	end	count
---------------------------------------------------------------------------------------------------------------------------------------------------------------
"""

def _select_best_lines(handle, threshold, compare_tsize):
    """Yield the best alignment lines for each run of lines sharing a query.

    *handle* is an iterable of whitespace-separated alignment lines without a
    header.  Consecutive lines with the same query name (column 9) form one
    group, and one list of lines is yielded per group (possibly empty).

    Within a group an alignment is scored as
    ``matches - misMatches - qBaseInsert - tBaseInsert - qStart - (qSize - qEnd)``
    and is only considered if its score is at least ``threshold * qSize``.
    Among those, the lines with the highest score are kept; ties are broken
    by the smallest target extent (``tEnd - tStart``) and, if *compare_tsize*
    is true, by the smallest ``tSize``.  Lines still tied are all kept, in
    input order.
    """
    current = None
    selected = []
    best_score = None
    shortest_extent = None
    shortest_tsize = None
    for line in handle:
        words = line.split()
        name = words[9]
        matches = int(words[0])
        mismatches = int(words[1])
        q_base_insert = int(words[5])
        t_base_insert = int(words[7])
        q_size = int(words[10])
        q_start = int(words[11])
        q_end = int(words[12])
        score = (matches - mismatches - q_base_insert - t_base_insert
                 - q_start - (q_size - q_end))
        t_size = int(words[14])
        extent = int(words[16]) - int(words[15])
        if name != current:
            if current is not None:
                yield selected
            current = name
            selected = []
            best_score = threshold * q_size
        if score < best_score:
            continue
        if score > best_score or not selected:
            # Strictly better than anything seen so far, or the first
            # acceptable alignment for this query (which may score exactly
            # the threshold): start afresh instead of comparing against
            # values left over from a previous query.
            best_score = score
            shortest_extent = extent
            shortest_tsize = t_size
            selected = [line]
            continue
        # Same score as the current best: break the tie on target extent.
        if extent > shortest_extent:
            continue
        if extent < shortest_extent:
            shortest_extent = extent
            shortest_tsize = t_size
            selected = [line]
            continue
        # Same extent as well: optionally break the tie on target size.
        if compare_tsize:
            if t_size > shortest_tsize:
                continue
            if t_size < shortest_tsize:
                shortest_tsize = t_size
                selected = [line]
                continue
        selected.append(line)
    if current is not None:
        yield selected


# Merge the per-chunk files into a single "<target>.psl" file, keeping only
# the best alignment line(s) for each query.
output_filename = "%s.psl" % target
print("Writing", output_filename)
total = 0
with open(output_filename, 'w') as output:
    output.write(header)
    for number, filename in enumerate(filenames):
        terms = filename.split(".")
        assert len(terms)==3
        assert terms[0]==target
        assert terms[2]=='psl'
        if number != int(terms[1]):
            # The sorted indices must be 0, 1, 2, ... without gaps.  Remove
            # the partial output before reporting the first missing index;
            # closing here is harmless as the with-statement's second close
            # is a no-op.
            output.close()
            os.remove(output_filename)
            raise Exception("Missing output file starting at %d" % number)
        print("Reading", filename)
        with open(filename) as handle:
            groups = _select_best_lines(handle, score_threshold,
                                        target != 'genome')
            for selected_lines in groups:
                # Every query group is counted, even if none of its lines
                # reached the score threshold.
                total += 1
                output.writelines(selected_lines)
print("%d sequences were mapped" % total)
